{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn import metrics\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "# depending on the classification model used, we might need to import other packages\n",
    "# from sklearn import svm\n",
    "# from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "from datasets import DatasetUCI\n",
    "from envs import LalEnvTargetAccuracy\n",
    "\n",
    "from helpers import Minibatch, ReplayBuffer\n",
    "from dqn import DQN\n",
    "from Test_AL import policy_rl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup and initialisation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Parameters for dataset and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Passed as `n_state_estimation` to the datasets and reused as the observation length of the agent.\n",
    "N_STATE_ESTIMATION = 30\n",
    "# Passed as `size` to the datasets.\n",
    "SIZE = 100\n",
    "# To train and test RL on the same dataset, use the even datapoints for training and the odd ones for testing.\n",
    "# -1: all datapoints, 0: even datapoints, 1: odd datapoints.\n",
    "SUBSET = -1\n",
    "# Can be set higher if we want to parallelise.\n",
    "N_JOBS = 1\n",
    "# Datasets used for training; the dataset that is used for testing has been removed from the full list\n",
    "# ['australian', 'breast_cancer', 'diabetis', 'flare_solar', 'german', 'heart', 'mushrooms', 'waveform', 'wdbc']\n",
    "possible_dataset_names = [\n",
    "    'breast_cancer',\n",
    "    'diabetis',\n",
    "    'flare_solar',\n",
    "    'german',\n",
    "    'heart',\n",
    "    'mushrooms',\n",
    "    'waveform',\n",
    "    'wdbc',\n",
    "]\n",
    "test_dataset_names = ['australian']\n",
    "# The quality is measured according to the given quality measure `quality_method`.\n",
    "QUALITY_METHOD = metrics.accuracy_score\n",
    "# The `tolerance_level` is the proportion of the max quality that needs to be achieved in order to terminate an episode.\n",
    "TOLERANCE_LEVEL = 0.98"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialise a dataset that will contain a sample of datapoints from one of the indicated datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Training data, built from the names in possible_dataset_names.\n",
    "dataset = DatasetUCI(\n",
    "    possible_dataset_names,\n",
    "    n_state_estimation=N_STATE_ESTIMATION,\n",
    "    subset=SUBSET,\n",
    "    size=SIZE,\n",
    ")\n",
    "# Test data, needed if we want to measure the test error along with training.\n",
    "# Note that subset is fixed to 1 here instead of using SUBSET.\n",
    "dataset_test = DatasetUCI(\n",
    "    test_dataset_names,\n",
    "    n_state_estimation=N_STATE_ESTIMATION,\n",
    "    subset=1,\n",
    "    size=SIZE,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialise the model that will be used for training a classifier. <br>\n",
    "It can be, for example, Logistic regression: <br>\n",
    "`LogisticRegression(n_jobs=N_JOBS)` <br>\n",
    "SVM: <br>\n",
    "`svm.SVC(probability=True)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Classification model; N_JOBS is forwarded as the number of parallel jobs.\n",
    "model = LogisticRegression(\n",
    "    n_jobs=N_JOBS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialise the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Keyword settings shared by the training and the test environment.\n",
    "env_settings = dict(quality_method=QUALITY_METHOD, tolerance_level=TOLERANCE_LEVEL)\n",
    "# NOTE(review): both environments receive the same `model` object - confirm that sharing it is intended.\n",
    "env = LalEnvTargetAccuracy(dataset, model, **env_settings)\n",
    "env_test = LalEnvTargetAccuracy(dataset_test, model, **env_settings)\n",
    "# Clear the default TensorFlow graph.\n",
    "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Parameters for training RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "DIRNAME = './agents/1-australian-logreg-8-to-1/' # The resulting agent of this experiment will be written to a file (passed to DQN as experiment_dir)\n",
    "\n",
    "# Replay buffer parameters.\n",
    "# NOTE(review): 1e4 is a float - confirm that ReplayBuffer accepts a non-integer buffer_size.\n",
    "REPLAY_BUFFER_SIZE = 1e4\n",
    "# Passed to ReplayBuffer as prior_exp. The name is misspelt but is referenced under this spelling elsewhere in the notebook.\n",
    "PRIOROTIZED_REPLAY_EXPONENT = 3\n",
    "\n",
    "# Agent parameters.\n",
    "# BATCH_SIZE is used both by the agent and when sampling minibatches from the replay buffer.\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-3\n",
    "TARGET_COPY_FACTOR = 0.01\n",
    "BIAS_INITIALIZATION = 0 # default 0 # will be set to minus half of the average episode duration during the warm start experiments\n",
    "\n",
    "# Warm start parameters.\n",
    "WARM_START_EPISODES = 128 # number of random episodes collected during the warm start; reduce for test\n",
    "NN_UPDATES_PER_WARM_START = 100 # number of network updates made on the warm start episodes\n",
    "\n",
    "# Episode simulation parameters.\n",
    "# Epsilon decreases linearly from EPSILON_START to EPSILON_END over EPSILON_STEPS training iterations\n",
    "# and stays at EPSILON_END afterwards.\n",
    "EPSILON_START = 1\n",
    "EPSILON_END = 0.1\n",
    "EPSILON_STEPS = 1000\n",
    "\n",
    "# Training parameters\n",
    "TRAINING_ITERATIONS = 1000 # reduce for test\n",
    "TRAINING_EPISODES_PER_ITERATION = 10 # at each training iteration x episodes are simulated\n",
    "NN_UPDATES_PER_ITERATION = 60 # at each training iteration x gradient steps are made\n",
    "\n",
    "# Validation and test parameters\n",
    "N_VALIDATION = 500 # number of validation episodes; reduce for test\n",
    "N_TEST = 500 # number of test episodes; reduce for test\n",
    "VALIDATION_TEST_FREQUENCY = 100 # every x iterations val and test are performed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialise replay buffer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Buffer that stores the transitions of the simulated episodes.\n",
    "replay_buffer = ReplayBuffer(\n",
    "    buffer_size=REPLAY_BUFFER_SIZE,\n",
    "    prior_exp=PRIOROTIZED_REPLAY_EXPONENT,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Warm start"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Warm-start the replay buffer with random episodes. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Collect episodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Durations of the warm-start episodes; their mean is used for the bias initialisation below.\n",
    "episode_durations = []\n",
    "for _ in range(WARM_START_EPISODES):\n",
    "    print('.', end='')\n",
    "    # Resetting the environment starts a new episode and returns\n",
    "    #   classifier_state: vector representation of the state of the environment (depends on the classifier),\n",
    "    #   next_action_state: vector representations of all actions available at the next step.\n",
    "    classifier_state, next_action_state = env.reset()\n",
    "    episode_duration = 0\n",
    "    # Keep making steps until a terminal state is reached.\n",
    "    while True:\n",
    "        # Pick one of the available actions at random.\n",
    "        action = np.random.randint(0, env.n_actions)\n",
    "        # Vector of the chosen action; read it before env.step() replaces next_action_state.\n",
    "        taken_action_state = next_action_state[:, action]\n",
    "        next_classifier_state, next_action_state, reward, terminal = env.step(action)\n",
    "        # Remember the transition in the replay buffer.\n",
    "        replay_buffer.store_transition(\n",
    "            classifier_state, taken_action_state, reward,\n",
    "            next_classifier_state, next_action_state, terminal,\n",
    "        )\n",
    "        # Move on to the next step.\n",
    "        classifier_state = next_classifier_state\n",
    "        episode_duration += 1\n",
    "        if terminal:\n",
    "            break\n",
    "    episode_durations.append(episode_duration)\n",
    "\n",
    "# Average duration of the episodes generated during the warm start procedure.\n",
    "av_episode_duration = np.mean(episode_durations)\n",
    "print('Average episode duration = ', av_episode_duration)\n",
    "\n",
    "# The bias initialisation becomes minus half of the average episode duration.\n",
    "BIAS_INITIALIZATION = -av_episode_duration/2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the DQN agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Settings of the DQN agent; bias_average takes the value computed from the warm-start episodes.\n",
    "agent_settings = dict(\n",
    "    experiment_dir=DIRNAME,\n",
    "    observation_length=N_STATE_ESTIMATION,\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    target_copy_factor=TARGET_COPY_FACTOR,\n",
    "    bias_average=BIAS_INITIALIZATION,\n",
    ")\n",
    "agent = DQN(**agent_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Do updates of the network based on warm start episodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for _ in range(NN_UPDATES_PER_WARM_START):\n",
    "    print('.', end='')\n",
    "    # Draw a batch from the replay buffer proportionally to the probability of sampling.\n",
    "    minibatch = replay_buffer.sample_minibatch(BATCH_SIZE)\n",
    "    # Train the agent on the batch; the returned temporal difference errors are fed back to the\n",
    "    # replay buffer so that the sampling probabilities are updated proportionally to the error.\n",
    "    replay_buffer.update_td_errors(agent.train(minibatch), minibatch.indeces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train RL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run multiple training iterations. Each iteration consists of:\n",
    "- generating episodes following agent's actions with exploration\n",
    "- validation and test episodes for evaluating performance\n",
    "- Q-network updates\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of training episodes simulated so far; used as the step of the per-episode summaries.\n",
    "i_episode = 0\n",
    "# List reserved for the rewards of the training episodes.\n",
    "train_episode_rewards = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def _mean_episode_duration(agent, environment, n_episodes):\n",
    "    \"\"\"Run n_episodes episodes in which policy_rl picks every action and return their mean duration.\n",
    "\n",
    "    Nothing is stored in the replay buffer. The duration of an episode is\n",
    "    len(environment.episode_qualities), read once the episode has terminated.\n",
    "    \"\"\"\n",
    "    durations = []\n",
    "    for _ in range(n_episodes):\n",
    "        done = False\n",
    "        state, next_action_state = environment.reset()\n",
    "        while not done:\n",
    "            action = policy_rl(agent, state, next_action_state)\n",
    "            state, next_action_state, _reward, done = environment.step(action)\n",
    "        durations.append(len(environment.episode_qualities))\n",
    "    return np.mean(durations)\n",
    "\n",
    "\n",
    "for iteration in range(TRAINING_ITERATIONS):\n",
    "    # GENERATE NEW EPISODES\n",
    "    # Compute epsilon value according to the schedule (linear decrease, floored at EPSILON_END).\n",
    "    epsilon = max(EPSILON_END, EPSILON_START-iteration*(EPSILON_START-EPSILON_END)/EPSILON_STEPS)\n",
    "    print(iteration, end=': ')\n",
    "    # Simulate training episodes.\n",
    "    for _ in range(TRAINING_EPISODES_PER_ITERATION):\n",
    "        # Reset the environment to start a new episode.\n",
    "        classifier_state, next_action_state = env.reset()\n",
    "        print(\".\", end='')\n",
    "        terminal = False\n",
    "        # Keep track of stats of episode to analyse it in tensorboard.\n",
    "        episode_reward = 0\n",
    "        episode_duration = 0\n",
    "        episode_summary = tf.Summary()\n",
    "        # Run an episode.\n",
    "        while not terminal:\n",
    "            # Let an agent choose an action.\n",
    "            action = agent.get_action(classifier_state, next_action_state)\n",
    "            # Get a prob of a datapoint corresponding to an action chosen by an agent.\n",
    "            # It is needed just for the tensorboard analysis.\n",
    "            rlchosen_action_state = next_action_state[0,action]\n",
    "            # With epsilon probability, take a random action.\n",
    "            if np.random.ranf() < epsilon:\n",
    "                action = np.random.randint(0, env.n_actions)\n",
    "            # taken_action_state is a vector that corresponds to a taken action;\n",
    "            # it has to be read before env.step() replaces next_action_state.\n",
    "            taken_action_state = next_action_state[:,action]\n",
    "            # Make another step.\n",
    "            next_classifier_state, next_action_state, reward, terminal = env.step(action)\n",
    "            # Store a step in replay buffer\n",
    "            replay_buffer.store_transition(classifier_state,\n",
    "                                           taken_action_state,\n",
    "                                           reward,\n",
    "                                           next_classifier_state,\n",
    "                                           next_action_state,\n",
    "                                           terminal)\n",
    "            # Change a state of environment.\n",
    "            classifier_state = next_classifier_state\n",
    "            # Keep track of stats and add summaries to tensorboard.\n",
    "            episode_reward += reward\n",
    "            episode_duration += 1\n",
    "            episode_summary.value.add(simple_value=rlchosen_action_state,\n",
    "                                      tag=\"episode/rlchosen_action_state\")\n",
    "            episode_summary.value.add(simple_value=taken_action_state[0],\n",
    "                                      tag=\"episode/taken_action_state\")\n",
    "        # Add summaries to tensorboard\n",
    "        episode_summary.value.add(simple_value=episode_reward,\n",
    "                                  tag=\"episode/episode_reward\")\n",
    "        episode_summary.value.add(simple_value=episode_duration,\n",
    "                                  tag=\"episode/episode_duration\")\n",
    "        i_episode += 1\n",
    "        # Record the reward in the list initialised before the loop (it was never filled before).\n",
    "        train_episode_rewards.append(episode_reward)\n",
    "        # Per-episode summaries use the episode counter as their step.\n",
    "        agent.summary_writer.add_summary(episode_summary, i_episode)\n",
    "        agent.summary_writer.flush()\n",
    "\n",
    "    # VALIDATION AND TEST EPISODES\n",
    "    # Per-iteration summaries get their own object and use the iteration number as their step.\n",
    "    iteration_summary = tf.Summary()\n",
    "    if iteration%VALIDATION_TEST_FREQUENCY == 0:\n",
    "        # Validation episodes are run. Use env for it.\n",
    "        iteration_summary.value.add(simple_value=_mean_episode_duration(agent, env, N_VALIDATION),\n",
    "                                    tag=\"episode/train_duration\")\n",
    "        # Test episodes are run. Use env_test for it.\n",
    "        iteration_summary.value.add(simple_value=_mean_episode_duration(agent, env_test, N_TEST),\n",
    "                                    tag=\"episode/test_duration\")\n",
    "\n",
    "    iteration_summary.value.add(simple_value=epsilon,\n",
    "                                tag=\"episode/epsilon\")\n",
    "    agent.summary_writer.add_summary(iteration_summary, iteration)\n",
    "    agent.summary_writer.flush()\n",
    "\n",
    "    # NEURAL NETWORK UPDATES\n",
    "    for _ in range(NN_UPDATES_PER_ITERATION):\n",
    "        # Sample a batch, train on it and feed the temporal difference errors back to the replay buffer.\n",
    "        minibatch = replay_buffer.sample_minibatch(BATCH_SIZE)\n",
    "        td_error = agent.train(minibatch)\n",
    "        replay_buffer.update_td_errors(td_error, minibatch.indeces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### To see the results in tensorboard\n",
    "\n",
    "on the server:\n",
    "`tensorboard --logdir=./`\n",
    "\n",
    "on the computer:\n",
    "`ssh -N -f -L localhost:6006:localhost:6006 konyushk@iccvlabsrv20.iccluster.epfl.ch && open http://localhost:6006`"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
